import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


@tilelang.testing.requires_cuda
def test_unroll_with_step():
    """Check that ``T.unroll`` with ``step`` produces an unroll pragma.

    Builds a small kernel whose innermost loop is ``T.unroll(0, 16, step=4)``,
    compiles it for the ``cuda`` target and asserts that the generated kernel
    source contains ``#pragma unroll``.

    The test compiles for CUDA, so it carries the ``requires_cuda`` marker;
    the intent is that environments without CUDA skip it instead of
    reporting a failure.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)

        # The block/thread loop variables are intentionally unused by the body.
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                # Loop under test: unrolled loop declared with step=4.
                for i in T.unroll(0, 16, step=4):
                    A[0, i] = 1.0

    kernel = tilelang.compile(main, target="cuda")
    source = kernel.get_kernel_source()
    assert "#pragma unroll" in source


@tilelang.testing.requires_cuda
def test_unroll_with_unroll_factor():
    """Check that ``T.unroll`` with ``unroll_factor`` is reflected in the pragma.

    Builds a small kernel whose innermost loop is
    ``T.unroll(0, 16, unroll_factor=4)``, compiles it for the ``cuda`` target
    and asserts that the generated kernel source contains
    ``#pragma unroll 4``.

    The test compiles for CUDA, so it carries the ``requires_cuda`` marker;
    the intent is that environments without CUDA skip it instead of
    reporting a failure.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)

        # The block/thread loop variables are intentionally unused by the body.
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                # Loop under test: unrolled loop declared with unroll_factor=4.
                for i in T.unroll(0, 16, unroll_factor=4):
                    A[0, i] = 1.0

    kernel = tilelang.compile(main, target="cuda")
    source = kernel.get_kernel_source()
    assert "#pragma unroll 4" in source


# When this file is executed directly as a script, hand control to
# tilelang.testing.main().
if __name__ == "__main__":
    tilelang.testing.main()
